# Owner(s): ["module: dynamo"]

# ruff: noqa: TRY002

import enum
import itertools
import operator
import types
import unittest
import weakref
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from collections.abc import Callable
from functools import partial
from typing import Any, NamedTuple

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.exc import Unsupported
from torch._dynamo.testing import same
from torch._dynamo.utils import dict_items
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    make_dynamo_test,
    munge_exc,
    parametrize,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class SimpleDict(dict):
    """Bare ``dict`` subclass with no overrides, used to exercise subclass handling."""


class DummyUserDict(UserDict):
    """Bare ``collections.UserDict`` subclass with no overrides."""


class FakeMapping:
    """Duck-typed mapping whose ``keys`` is an instance attribute, not a method.

    Every lookup returns the same stored value, whatever key is requested.
    """

    def __init__(self, value: Any) -> None:
        def _fixed_keys():
            return ["a", "b", "c"]

        self._value = value
        # Stored on the instance on purpose: the mapping protocol only needs
        # ``keys`` to be callable, not to be a bound method.
        self.keys = _fixed_keys

    def __getitem__(self, key: str) -> Any:
        # The key is ignored; all keys map to the single stored value.
        return self._value


class DictTests(torch._dynamo.test_case.TestCase):
    def test_dict_subclass_instantiation(self):
        # Constructing a plain dict subclass with keyword arguments inside a
        # fullgraph compile must produce the same result as eager.
        def fn(x):
            sd = SimpleDict(x=5)
            return sd["x"] * x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_mutation(self):
        # A dict subclass created inside the compiled region is read, updated
        # via item assignment, and read again; the result must match eager.
        def fn(x):
            sd = SimpleDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return z * sd["x"]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains_enum(self):
        # `in` checks on set literals that mix a str-based Enum member with
        # plain strings (in both directions) must evaluate as in eager.
        class TensorDim(str, enum.Enum):
            DDP = "ddp"
            FSDP = "fsdp"
            CP = "cp"
            TP = "tp"

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                val = x.sin()
                # Enum member looked up in a set of plain strings.
                if TensorDim.DDP in {"ddp"}:
                    val += x.cos()
                # Plain string looked up in a set of Enum members.
                if "ddp" in {TensorDim.DDP}:
                    val += x.cos()
                return val

        inp = torch.randn(4, 4)
        mod = Foo()
        opt_f = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(inp), opt_f(inp))

    def test_dict_subclass_local_with_non_dict_method(self):
        # Checks that add_1 method is inlined
        # (a user-defined, non-dict method on a locally created dict subclass
        # must be traceable under fullgraph=True).
        class MethodDict(dict):
            def add_1(self, x):
                return x + 1

        def fn(x):
            sd = MethodDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return sd.add_1(z * sd["x"])

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains(self):
        # `in` on a closed-over dict selects a branch at trace time. Changing
        # whether the tested key is present must recompile; changing the value
        # of an unrelated key must not.
        sd = dict()
        sd[2] = 5
        sd[4] = 10

        def fn(x):
            if 1 in sd:
                x = x * 2
            else:
                x = x * 3
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation (key 1 is now present, flipping the branch).
        sd[1] = 15
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure no recompilation, because the traced program remains the same here.
        sd[2] = 10
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_methods_fallback_readonly(self):
        # Read-only dict APIs (values, iteration, items, `in`, get, len) and a
        # regular instance attribute on a closed-over dict subclass must match
        # eager.
        sd = SimpleDict()
        sd[2] = 5
        sd[4] = 10
        # check that regular attr accesses work well
        sd.attr = 4

        def fn(x):
            for value in sd.values():
                x = x * value
            for key in sd:
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            # One present key (2) and one absent key (3) to cover both get paths.
            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            x = x * sd.attr
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation (the key set changed and fn iterates all keys).
        sd[6] = 15
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_instantiation_return(self):
        # Returning a dict subclass built inside the graph must reconstruct an
        # object of the same type holding the same entries as eager.
        def fn(x):
            sd = SimpleDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_dict_subclass_methods_fallback_mutation(self):
        # Same API mix as the read-only variant, but the dict subclass is a
        # graph input and is mutated mid-function (sd[6] = 14). The returned
        # value and the final dict contents must both match eager.
        def fn(sd, x):
            for value in sd.values():
                x = x * value
            sd[6] = 14
            for key in sd:
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        # Two identical dicts: sd1 is run eagerly, sd2 through the compiled fn.
        sd1 = SimpleDict()
        sd1[2] = 5
        sd1[4] = 10

        sd2 = SimpleDict()
        sd2[2] = 5
        sd2[4] = 10
        self.assertTrue(sd1 == sd2)

        self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
        # The in-graph mutation must be replayed onto sd2.
        self.assertTrue(sd1 == sd2)

    def test_dict_subclass_setitem(self):
        # A user-defined __setitem__ override (which stores value + 1) must be
        # honored exactly as in eager, for both construction and assignment.
        class SetItemDict(dict):
            def __setitem__(self, key, value):
                super().__setitem__(key, value + 1)

        def fn(x):
            sd = SetItemDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_custom_iter_dict(self):
        # A dict subclass with a custom __iter__ is passed as an input and
        # mutated (both an attribute and an item) inside the compiled fn.
        class ReversedDict(dict):
            def __iter__(self):
                return reversed(list(self.keys()))

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            # Forces side effects attribute reapplication logic
            d.sample = 1
            d["baz"] = 4
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        # This is intentional because the dict is mutated, so we will have a recompilation.
        fn(torch.randn(4), d)
        # Once the dict contents stop changing, no further recompiles are allowed.
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_keys_iter_dict(self):
        # A dict subclass overriding keys() (returning a list) used as an input;
        # a second identical call must not recompile.
        class ReversedDict(dict):
            def keys(self):
                return ["bar", "foo"]

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_dict_guard_on_keys_order(self):
        # fn iterates d.items(), so the result depends on key order. Reordering
        # the same keys must trigger exactly one recompile and a correct result.
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key, value in d.items():
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_dict_guard_on_keys_order2(self):
        # Same as test_dict_guard_on_keys_order, but iterating the dict itself
        # and indexing instead of using items().
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key in d:
                value = d[key]
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_ordered_dict_reordered_keys(self):
        # An OrderedDict reordered with move_to_end before compilation; the
        # sin/cos choice depends on each value's position, so the compiled fn
        # must observe the post-move order.
        d = OrderedDict()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, value in enumerate(d.values()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_ordered_dict_subclass_reordered_keys(self):
        # Same as test_ordered_dict_reordered_keys, but for an OrderedDict
        # subclass that overrides keys() (delegating to super()).
        class ODSubclass(OrderedDict):
            def keys(self):
                return super().keys()

        d = ODSubclass()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, value in enumerate(d.values()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_lazy_key_guarding(self):
        # Only keys actually read by the compiled fn should be guarded; adding
        # or removing other keys of a closed-over dict must not recompile.
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            return x * d["a"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key c was not used, it should not lead to a recompilation
        d.pop("c")
        d["d"] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_lazy_key_non_const_guarding(self):
        # Same as test_lazy_key_guarding, but the keys are non-constant objects
        # (type/function objects) rather than strings.
        d = {
            list: 2,
            dict: 3,
            OrderedDict: 5,
            namedtuple: 7,
        }

        def fn(x):
            return x * d[list]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key `dict` was not used, removing it (and adding `defaultdict`)
        # should not lead to a recompilation
        d.pop(dict)
        d[defaultdict] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_dict_mutation_side_effect(self):
        # In-place mutations (item assignment + pop) on an input dict must be
        # replayed after the compiled run, and the very same dict object must be
        # returned, matching eager.
        def fn(d):
            d["c"] = d["a"] + d.pop("b")
            return d

        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
        # Shallow copy: same tensors, but an independent dict object for the
        # compiled run.
        args2 = dict(args1)
        # Use the TestCase assertion instead of a bare ``assert`` so the check
        # is not stripped under ``python -O`` and reports a useful message.
        self.assertIs(fn(args1), args1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertIs(opt_fn(args2), args2)
        self.assertTrue(same(args1, args2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)

    def test_dict_copy_alias(self):
        # d0.copy() inside the graph must yield a distinct dict: mutating the
        # copy must leave the input dict untouched.
        @torch.compile(backend="eager", fullgraph=True)
        def run(x, d0):
            d1 = d0.copy()
            d1[0] = 1
            return x + 1, d1

        d0 = {}
        res, d1 = run(torch.zeros(1), d0)
        self.assertTrue(same(res, torch.ones(1)))
        self.assertEqual(d0, {})
        self.assertEqual(d1, {0: 1})

    def test_dict_subclass_get_method(self):
        # A dict subclass that aliases attribute access onto dict.get /
        # __setitem__ / __delitem__; calling the real .get() method on a
        # closed-over instance must still trace under fullgraph=True.
        class dotdict(dict):
            """dot.notation access to dictionary attributes"""

            __getattr__ = dict.get
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__

        config = dotdict({"a": 1, "b": 2})

        def fn(x):
            x2 = x * 2  # noqa: F841
            x3 = x * config.get("a", 3)
            return x3

        x = torch.randn(2)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_order_keys(self):
        # Guards on a dict input include key order: the same keys inserted in a
        # different order must recompile, and later calls must reuse the
        # existing cache entries.
        def fn(d):
            c = 0
            for v in d.values():
                c += v
            return c

        args1 = {}
        args1["a"] = torch.rand(10)
        args1["b"] = torch.rand(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # A different order of keys recompiles
        args2 = {}
        args2["b"] = args1["b"]
        args2["a"] = args1["a"]
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

        # Extra calls (with either key order) don't recompile. The calls must
        # actually happen, otherwise the assertion below checks nothing new.
        opt_fn(args1)
        opt_fn(args2)
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_namedtuple(self):
        # The `namedtuple` function object is used as a dict key; `in` on it
        # picks a branch, so an input without that key must recompile.
        def fn(d):
            if namedtuple in d:
                return d[3] * 2
            else:
                return d[3] * 3

        args1 = {namedtuple: None, 3: torch.randn(3)}
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        # Test a failing namedtuple guard
        args2 = {2: None, 3: torch.randn(3)}
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_tensors(self):
        # A tensor used as a dict key (looked up by identity) alongside an int
        # key; key order is still guarded.
        def fn(d, x):
            return d[x] + 3

        args1 = {}
        x = torch.randn(10)
        y = torch.randn(10)
        z = torch.randn(10)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_modules(self):
        # Same as test_dict_order_keys_tensors, but with nn.Module objects as
        # the key and the (called) value.
        def fn(d, x):
            return d[x](torch.ones(2, 2))

        args1 = {}
        x = torch.nn.Linear(2, 2)
        y = torch.nn.Linear(2, 2)
        z = torch.nn.Linear(2, 2)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_contains_dunder_dict(self):
        # `in` and .get() on an object's __dict__, including an attribute (c)
        # that is first assigned during the compiled call.
        class UserDefined:
            def __init__(self) -> None:
                self.a = 3
                self.b = 5

            def run(self, x):
                if "a" in self.__dict__:
                    x = x * self.a
                if "b" in self.__dict__:
                    x = x * self.b
                self.c = 7
                if "c" in self.__dict__:
                    x = x * self.c
                # "z" is absent, so the default 2 is used.
                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)

        obj = UserDefined()

        def fn(x):
            return obj.run(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_contains_module_dunder_dict(self):
        # `in` on an nn.Module's __dict__ for a plain (non-parameter) attribute.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = 1
                self.bar = 2
                self.baz = 3

            def forward(self, x):
                if "foo" in self.__dict__:
                    return x * self.bar
                return x * self.baz

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_update_dunder_dict(self):
        # Writing through obj.__dict__ in compiled code must be visible via
        # normal attribute access and leave the same __dict__ as eager.
        class UserDefined:
            def run(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        obj1 = UserDefined()
        obj2 = UserDefined()

        def fn(x, obj):
            return obj.run(x)

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, obj1)
        res = opt_fn(x, obj2)
        self.assertEqual(ref, res)
        # Make sure only `a` is updated.
        self.assertEqual(obj1.__dict__, obj2.__dict__)

    def test_update_module_dunder_dict(self):
        # Same as test_update_dunder_dict, but on an nn.Module inside forward().
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_dict_reconstruct_keeps_original_order(self):
        # An OrderedDict and a ModuleDict built and updated inside the compiled
        # region must come back in insertion order and hold the very same
        # module objects in the same positions.
        def fn():
            modules = OrderedDict([("act", torch.nn.ReLU())])
            module_dict = torch.nn.ModuleDict(modules)

            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
            modules.update(next_modules.items())
            module_dict.update(next_modules)
            return modules, module_dict

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        modules, module_dict = opt_fn()

        self.assertEqual(len(module_dict), len(modules))
        for k1, m2 in zip(modules, module_dict.children()):
            # assertIs reports both objects on failure, unlike assertTrue(a is b).
            self.assertIs(modules[k1], m2)

    # FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update
    @unittest.expectedFailure
    def test_dict_construct_from_mapping_like(self):
        # dict() built from a duck-typed mapping (callable `keys` attribute plus
        # __getitem__) combined with keyword arguments. Currently expected to
        # fail under compile; see the FIXME above.
        def fn(x):
            fm = FakeMapping(x)
            d = dict(fm, x=x)
            return d

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_initialization_in_graph(self):
        # Subclasses of both OrderedDict and dict with explicit __new__ and
        # __init__ overrides, instantiated inside a fullgraph compile.
        for super_class in (
            OrderedDict,
            dict,
        ):

            class CustomDict(super_class):
                def __new__(cls, *args, **kwargs):
                    return super().__new__(cls, *args, **kwargs)

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

            def fn(x):
                c = CustomDict()
                c["key"] = x
                assert "key" in c
                return c["key"] + 1

            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

            x = torch.rand(4)
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_list_values(self):
        # Smoke test (no assertions): a list stored in a dict input is zipped
        # with itertools.count() and consumed by a helper; the second call
        # uses a list of a different length.
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch.compile(backend="eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        fn({"args": [torch.ones(5, 5)]})

    def test_dict_iter(self):
        # Iterating a dict literal created inside forward(); the values sum to
        # 1 + 2 + 3 + 4 = 10.
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_dict_subclass_contains(self):
        # pattern from huggingface
        # `in` on an OrderedDict subclass input; each call has a different key,
        # so the branches taken differ: 1 + 2 + 8 = 11 and 1 + 4 + 8 = 13.
        class ClassInstantier(OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))

    def test_dict_tag_guard(self):
        # Changing a dict value, and later an attribute of an object stored in
        # the dict, must both be reflected in the compiled results.
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Check that recompilation happens
        # (only verified indirectly, via the result matching eager).
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

    def test_empty_dict_recompilation(self):
        # The truthiness of a dict input selects the branch; the empty and
        # non-empty cases must each produce the eager result.
        def fn(d, x):
            if d:
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn({}, x), opt_fn({}, x))
        self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x))

    def test_udf_dict_reconstruction(self):
        # dict.__new__(klass) inside the graph, for a user subclass (which also
        # gets an attribute) and for plain dict. The returned object must keep
        # its type, contents and attributes.
        class MyDict(dict):
            pass

        def fn(x, klass):
            x = x * 2
            sc_dict = dict.__new__(klass)
            sc_dict["x"] = x
            if isinstance(sc_dict, MyDict):
                sc_dict.attr = 3
            return sc_dict

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x, MyDict)
        res = opt_fn(x, MyDict)
        self.assertEqual(ref, res)
        self.assertTrue(isinstance(res, MyDict))
        self.assertEqual(ref.attr, res.attr)

        ref = fn(x, dict)
        res = opt_fn(x, dict)
        self.assertEqual(ref, res)
        self.assertTrue(isinstance(res, dict))

    def test_weakref_dict(self):
        # `in` on a closed-over WeakKeyDictionary whose keys are modules.
        states = weakref.WeakKeyDictionary()

        mod1 = torch.nn.Module()
        mod2 = torch.nn.Module()

        states[mod1] = 2
        states[mod2] = 3

        def fn(x):
            if mod1 in states:
                x = torch.sin(x)
            if mod2 in states:
                x = torch.cos(x)
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_construct_user_dict_and_return(self):
        # A UserDict subclass constructed in the graph and returned to the
        # caller must hold the same value as the eager result.
        def fn(x):
            return DummyUserDict({"a": x + 1})

        x = torch.randn(4)
        res = fn(x)
        self.assertEqual(res["a"], x + 1)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(res["a"], opt_fn(x)["a"])

    def test_fn_id(self):
        # id() of a function input used as a dict key, both when the dict is
        # built and when it is looked up.
        def fn(x, f):
            d = {id(f): 3}
            return x * d[id(f)]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)

        def nothing():
            pass

        f = nothing
        self.assertEqual(fn(x, f), opt_fn(x, f))

    def test_mapping_proxy_for_local(self):
        # A MappingProxyType wrapping a dict created in the graph: it is read,
        # iterated and returned, and must come back as a real MappingProxyType.
        def fn(x):
            d = {"a": 2, "b": 3, "c": 5 * x}
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertTrue(type(res) is types.MappingProxyType)

    def test_mapping_proxy_for_nonlocal(self):
        # A MappingProxyType created in the graph over a closed-over dict that
        # is also mutated in the graph. The returned proxy must stay a live view
        # of that same dict.
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            d["d"] = 4
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertTrue(type(res) is types.MappingProxyType)

        # check update to d is reflected in res
        d["e"] = 5
        self.assertEqual(d["e"], res["e"])

    def test_mapping_proxy_existing(self):
        # A pre-existing MappingProxyType passed as an input. Changing a value
        # and removing a key in the underlying dict between calls must be seen
        # by the compiled fn.
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            if isinstance(mp, types.MappingProxyType):
                y *= 2
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d["a"] = 3
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d.pop("b")
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_dict_construction_from_mapping_proxy(self):
        # dict(mp) inside the graph, where mp is a MappingProxyType input.
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            d = dict(mp)
            y = torch.sin(x * d["a"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_mapping_proxy_existing_mutation(self):
        # The compiled fn mutates a dict that an existing (closed-over) proxy
        # views, then reads the new key through the proxy. fullgraph is not
        # requested, so a graph break is tolerated here.
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            d["d"] = 4
            y = torch.sin(x * mp["d"])
            return y

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = torch.sin(x * 4)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_mapping_proxy_existing_local_mutation(self):
        # Mutating a dict created inside the graph must not be treated as a
        # mutation visible through an unrelated, pre-existing proxy
        # (fullgraph=True enforces that no graph break occurs).
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            # Dynamo should not cause a graph break here because it knows that
            # the existing proxy can't point to this new dict
            other_dict = {}
            other_dict["d"] = 4
            y = torch.sin(x * mp["c"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = torch.sin(x * mp["c"])
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_move_to_end(self):
        # OrderedDict.move_to_end inside the graph must be reflected in the key
        # order of the returned OrderedDict.
        def fn(x):
            d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
            d.move_to_end("a")
            return d

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
        self.assertEqual(fn(x), opt_fn(x))

    def test_mapping_proxy_ban_muation_on_dict_realization(self):
        # After copying a class's __dict__ into a plain dict, setting a new
        # class attribute must still be visible through Foo.__dict__ and on the
        # returned class. fullgraph is not requested, so a graph break is allowed.
        def fn(x):
            class Foo:
                b = 4

            d = dict(Foo.__dict__)
            y = torch.sin(x) * d["b"]
            # This should cause a graph break, because otherwise the
            # Foo.__dict__ will not be updated.
            Foo.bar = 3
            return Foo, y * Foo.__dict__["bar"]

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        foo1, ref = fn(x)
        foo2, res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(foo1.bar, foo2.bar)

    def test_overridden_get_item(self):
        # A user __getitem__ override that returns value + 1 and counts its
        # calls. The compiled run must match both the result and the number of
        # __getitem__ calls seen in eager.
        class MyDict(dict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.calls = 0

            def __getitem__(self, key):
                self.calls += 1
                return super().__getitem__(key) + 1

        def fn(x, d):
            d["d"] = 4
            return x * d["a"] + d["b"] + d["c"] + d["d"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d1 = MyDict({"a": 2, "b": 3, "c": 5})
        ref = fn(x, d1)

        d2 = MyDict({"a": 2, "b": 3, "c": 5})
        res = opt_fn(x, d2)
        self.assertEqual(ref, res)
        self.assertEqual(d1.calls, d2.calls)

    def test_items_type(self):
        # Returning d.items() from the graph must produce a real dict_items view
        # equal to the eager one.
        def fn():
            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})  # noqa: C418
            return d.items()

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn()
        res = opt_fn()
        self.assertEqual(ref, res)
        self.assertEqual(type(res), dict_items)

    def test_builtin_or_with_invalid_types(self):
        # `dict | <non-dict>` raises in Python. Under fullgraph=True the raised
        # exception is expected to surface as Unsupported("Observed exception").
        args = (
            1,  # int
            1.0,  # float
            "a",  # str
            (1, 2),  # tuple
            [1, 2],  # list
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(b: Any):
            a = {"one": torch.ones(1)}
            return a | b

        for arg in args:
            with self.assertRaisesRegex(Unsupported, "Observed exception"):
                _ = fn(arg)

    def test_builtin_or_with_diff_keys(self):
        # dict union (operator and __or__ method, both operand orders) with
        # disjoint keys; the operands must be left unchanged.
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_or_with_same_keys(self):
        # dict union with an overlapping key ("one"); the right operand's value
        # wins, so operand order matters and both orders are compared.
        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_ior_(self):
        # In-place dict union (|=) must update `a` and leave `b` untouched.
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            a |= b
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_newly_constructed_default_dict(self):
        # A defaultdict(list) created in the graph, assigned to, and returned.
        def f(x):
            d = defaultdict(list)
            d[0] = [
                42,
            ]
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    def test_newly_constructed_default_dict_no_default_factory(self):
        def f1(x):
            d = defaultdict()
            try:
                d[1] += 42
            except KeyError:
                d[1] = 1
            return x + 1, d

        x = torch.ones(2)
        ref = f1(x)
        res = torch.compile(f1, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

        def f2(x):
            d = defaultdict(None)
            try:
                d[1] += 42
            except KeyError:
                d[1] = 1
            return x + 1, d

        ref = f2(x)
        res = torch.compile(f2, backend="eager", fullgraph=True)(x)
        self.assertEqual(ref, res)

        def f3(x):
            d = defaultdict(None, {1: 10})
            d[1] += 42
            try:
                d[2] += 24
            except KeyError:
                d[2] = 1
            return x + 1, d

        ref = f3(x)
        res = torch.compile(f3, backend="eager", fullgraph=True)(x)
        self.assertEqual(ref, res)

    def test_newly_constructed_default_dict_with_dict(self):
        def f(x):
            d = dict([("a", 1), ("b", 2)], c=3)  # noqa: C406
            dd = defaultdict(list, d, d=4, e=5)
            dd["x"].append(42)
            return x + 1, d, dd

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    def test_iter_default_dict(self):
        def f(x):
            d = defaultdict(list)
            d[0] = 42
            for k in d:
                d[k] += 1
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    @parametrize("op", ["or_", "and_", "xor", "sub"])
    def test_dict_keys_binop(self, op):
        op = getattr(operator, op)

        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return op(a.keys(), b.keys()), op(b.keys(), a.keys())

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    @parametrize("op", ["ior", "iand", "ixor", "isub"])
    def test_dict_keys_inplace_binop(self, op):
        op = getattr(operator, op)

        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
            b = {"one": torch.ones(1), "three": torch.ones(3)}.keys()
            c = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
            a = op(a, b)
            b = op(b, c)
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_range_as_dict_key(self):
        def fn(x):
            d = {range(5): x * 2, range(10, 15): x * 3}
            return d[range(0, 5, 1)] + d[range(10, 15)]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_tuple_as_dict_key(self):
        def fn(x):
            d = {(1, 2): x * 2, (3, 4, 5): x * 3}
            return d[(1, 2)] + d[(3, 4, 5)]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_enum_as_dict_key(self):
        class Color(enum.Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        def fn(x):
            d = {Color.RED: x * 2, Color.GREEN: x * 3, Color.BLUE: x * 4}
            return d[Color.RED] + d[Color.GREEN] + d[Color.BLUE]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_intenum_as_dict_key(self):
        class Priority(enum.IntEnum):
            LOW = 1
            MEDIUM = 2
            HIGH = 3

        def fn(x):
            d = {Priority.LOW: x * 2, Priority.MEDIUM: x * 3, Priority.HIGH: x * 4}
            return d[Priority.LOW] + d[Priority.MEDIUM] + d[Priority.HIGH]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_frozenset_as_dict_key(self):
        def fn(x):
            d = {frozenset([1, 2]): x * 2, frozenset([3, 4, 5]): x * 3}
            return d[frozenset([1, 2])] + d[frozenset([3, 4, 5])]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_typing_union_as_dict_key(self):
        from typing import Union

        def fn(x):
            d = {Union[int, str]: x * 2, Union[float, bool]: x * 3}
            return d[Union[int, str]] + d[Union[float, bool]]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_numpy_dtype_as_dict_key(self):
        import numpy as np

        def fn(x):
            d = {np.float32: x * 2, np.int64: x * 3, np.bool_: x * 4}
            return d[np.float32] + d[np.int64] + d[np.bool_]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_method_wrapper_as_dict_key(self):
        add_method = list.__add__
        mul_method = list.__mul__

        def fn(x):
            # Method wrappers are the type of bound methods on built-in types
            d = {add_method: x * 2, mul_method: x * 3}
            return d[add_method] + d[mul_method]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_torch_builtin_function_as_dict_key(self):
        def fn(x, y):
            # Using torch built-in functions as dictionary keys
            d = {torch.add: x * 2, torch.mul: y * 3, torch.sub: x + y}
            return d[torch.add] + d[torch.mul] + d[torch.sub]

        x = torch.randn(4)
        y = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x, y), opt_fn(x, y))

    def test_frozen_dataclass_as_dict_key(self):
        from dataclasses import dataclass

        @dataclass(frozen=True)
        class Point:
            x: int
            y: int

        def fn(tensor):
            p1 = Point(1, 2)
            p2 = Point(3, 4)
            d = {p1: tensor * 2, p2: tensor * 3}
            return d[Point(1, 2)] + d[Point(3, 4)]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_list_as_dict_key_raises_typeerror(self):
        def fn(x):
            d = {[1, 2, 3]: x * 2}
            return d[[1, 2, 3]]

        x = torch.randn(4)

        # First check that eager execution raises TypeError
        with self.assertRaises(TypeError):
            fn(x)

        # Also check that compiled version raises TypeError
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        with self.assertRaisesRegex(Unsupported, "Observed exception"):
            opt_fn(x)

    def test_get_default_nowrap_functions_as_dict_key(self):
        def fn(x):
            # Get the set of default nowrap functions
            nowrap_funcs = torch.overrides.get_default_nowrap_functions()
            # Use the set as a dict key and search for Tensor.grad.__get__ in it
            d = {frozenset(nowrap_funcs): x * 2}
            # Check if Tensor.grad.__get__ is in the set
            if torch.Tensor.grad.__get__ in nowrap_funcs:
                return d[frozenset(nowrap_funcs)] + x
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))


# Expand the @parametrize-decorated DictTests methods (e.g. test_dict_keys_binop)
# into one concrete test per parameter value.
instantiate_parametrized_tests(DictTests)


class DictGuardTests(LoggingTestCase):
    thetype = dict

    @make_logging_test(recompiles=True)
    def test_popitem(self, records):
        d = self.thetype()
        d[1] = 2
        d[3] = 4

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            k, v = d.popitem()
            if k == 3 and v == 4:
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        y = fn(x)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(y, x.sin())

        d[3] = 5
        y = fn(x)
        self.assertEqual(len(records), 1)
        self.assertEqual(y, x.cos())
        record = self.getRecord(records, "d")
        self.assertIn(
            """d[3] == 4""",
            munge_exc(record),
        )

    @make_logging_test(recompiles=True)
    def test_cmp_eq(self, records):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, d1, d2):
            if d1 == d2:
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        d1 = self.thetype({1: 2, 3: 4})
        d2 = self.thetype({1: 2, 5: 6})
        y = fn(x, d1, d2)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(y, x.cos())

        y = fn(x, d1, d1)
        self.assertEqual(len(records), 1)
        self.assertEqual(y, x.sin())
        record = self.getRecord(records, "d2")
        self.assertIn(
            """list(dict.keys(d2))""",
            munge_exc(record.getMessage()),
        )

    @make_logging_test(recompiles=True)
    def test_cmp_ne(self, records):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, d1, d2):
            if d1 == d2:
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        d1 = self.thetype({1: 2, 3: 4})
        d2 = self.thetype({1: 2, 5: 6})
        y = fn(x, d1, d2)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(y, x.cos())

        y = fn(x, d1, d1)
        self.assertEqual(len(records), 1)
        self.assertEqual(y, x.sin())
        record = self.getRecord(records, "d2")
        self.assertIn(
            """list(dict.keys(d2))""",
            munge_exc(record.getMessage()),
        )

    @make_logging_test(recompiles=True)
    def test_cmp_or(self, records):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, d1, d2):
            d = d1 | d2
            if d.get(5, False):
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        d1 = self.thetype({1: 2, 3: 4})
        d2 = self.thetype({1: 2, 5: 6})
        y = fn(x, d1, d2)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(y, x.sin())

        y = fn(x, d1, d1)
        self.assertEqual(len(records), 1)
        self.assertEqual(y, x.cos())
        record = self.getRecord(records, "d2")
        self.assertIn(
            """KeyError on d2[5]""",
            munge_exc(record.getMessage()),
        )

    @make_logging_test(recompiles=True)
    def test_cmp_ior(self, records):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, d1, d2):
            d2 |= d1
            if d2.get(3, False):
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        d1 = self.thetype({1: 2, 3: 4})
        d2 = self.thetype({1: 2, 5: 6})
        d3, d4 = d2.copy(), d2.copy()
        y = fn(x, d1, d2)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(y, x.sin())

        y = fn(x, d3, d4)
        self.assertEqual(len(records), 1)
        self.assertEqual(y, x.cos())
        record = self.getRecord(records, "d1")
        self.assertIn(
            """KeyError on d1[3]""",
            munge_exc(record.getMessage()),
        )


class DictMethodsTests(torch._dynamo.test_case.TestCase):
    thetype = dict

    # Methods:
    # + clear
    # + copy
    # + fromkeys
    # + get
    # + items
    # + keys
    # + pop
    # + popitem
    # + setdefault
    # + update
    # + values
    # BinOps:
    # ==, !=, |

    def setUp(self):
        self._prev_trace_unittest = torch._dynamo.config.enable_trace_unittest
        torch._dynamo.config.enable_trace_unittest = True
        super().setUp()

    def tearDown(self):
        torch._dynamo.config.enable_trace_unittest = self._prev_trace_unittest
        return super().tearDown()

    def assertEqual(self, x, y):
        self.assertTrue(x == y, f"Expected {x} to be equal to {y}")

    def assertNotEqual(self, x, y):
        self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")

    @make_dynamo_test
    def test_cmp_eq(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"a": 1, "b": 2})
        d3 = self.thetype({"a": 1, "b": 3})
        self.assertEqual(d1, d2)
        self.assertNotEqual(d1, d3)

        # Test the == operator
        self.assertEqual(d1 == d2, True)
        self.assertEqual(d1 == d3, False)

        # Test the __eq__ method
        self.assertEqual(d1.__eq__(d2), True)
        self.assertEqual(d1.__eq__(d3), False)

        # Test Dict.__eq__
        self.assertEqual(dict.__eq__(d1, d2), True)
        self.assertEqual(self.thetype.__eq__(d1, d3), False)

    @make_dynamo_test
    def test_cmp_ne(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"a": 1, "b": 2})
        d3 = self.thetype({"a": 1, "b": 3})
        self.assertNotEqual(d1, d3)
        self.assertEqual(d1, d2)

        # Test the != operator
        self.assertEqual(d1 != d3, True)
        self.assertEqual(d1 != d2, False)

        # Test the __ne__ method
        self.assertEqual(d1.__ne__(d3), True)
        self.assertEqual(d1.__ne__(d2), False)

        # Test Dict.__ne__
        self.assertEqual(dict.__ne__(d1, d3), True)
        self.assertEqual(self.thetype.__ne__(d1, d2), False)

    @make_dynamo_test
    def test_binop_or(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})

        # Test the | operator
        self.assertEqual(d1 | d2, {"a": 1, "b": 3, "c": 4})
        self.assertEqual(d2 | d1, {"a": 1, "b": 2, "c": 4})

        # Test the __or__ method
        self.assertEqual(d1.__or__(d2), {"a": 1, "b": 3, "c": 4})
        self.assertEqual(d2.__or__(d1), {"a": 1, "b": 2, "c": 4})

        # Test Dict.__or__
        self.assertEqual(dict.__or__(d1, d2), {"a": 1, "b": 3, "c": 4})
        self.assertEqual(self.thetype.__or__(d2, d1), {"a": 1, "b": 2, "c": 4})

        # Test with non-dict types
        self.assertRaises(TypeError, lambda: d1 | 1)

    @make_dynamo_test
    def test_binop_ior(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})

        # Test the |= operator
        d3, d4 = d1.copy(), d2.copy()
        d3 |= d2
        d4 |= d1
        self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
        self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})

        # Test with an iterable
        d3, d4 = d1.copy(), d2.copy()

        # Test the __ior__ method
        d3, d4 = d1.copy(), d2.copy()
        d3.__ior__(d2)
        d4.__ior__(d1)
        self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
        self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})

        # Test Dict.__or__
        d3, d4 = d1.copy(), d2.copy()
        self.assertEqual(dict.__ior__(d3, d2), {"a": 1, "b": 3, "c": 4})
        self.assertEqual(self.thetype.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})

        # Test return value
        d3, d4 = d1.copy(), d2.copy()
        self.assertEqual(d3.__ior__(d2), {"a": 1, "b": 3, "c": 4})
        self.assertEqual(dict.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})

        # Test with non-dict types
        self.assertRaises(TypeError, lambda: dict.__ior__(d1, 1))

    @make_dynamo_test
    def test_binop_ior_iterable(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})
        d3, d4 = d1.copy(), d2.copy()

        def fn(d):
            yield from d.items()

        self.assertEqual(d3.__ior__(d2.items()), {"a": 1, "b": 3, "c": 4})
        self.assertEqual(d4.__ior__(fn(d1)), {"a": 1, "b": 2, "c": 4})

    @make_dynamo_test
    def test_clear(self):
        d = self.thetype({"a": 1, "b": 2})
        d.clear()
        self.assertEqual(d, {})

        # Test that clear returns None
        d = self.thetype({"a": 1, "b": 2})
        self.assertIsNone(d.clear())

        # Test Dict.clear
        d = self.thetype({"a": 1, "b": 2})
        dict.clear(d)
        self.assertEqual(d, {})

        d = self.thetype({"a": 1, "b": 2})
        self.thetype.clear(d)
        self.assertEqual(d, {})

        # Test invalid usage
        self.assertRaises(TypeError, d.clear, 1)

    @make_dynamo_test
    def test_copy(self):
        d = self.thetype({"a": 1, "b": 2})
        d2 = d.copy()
        self.assertEqual(d, d2)

        # Test that copy returns a new instance
        self.assertIsNot(d, d2)

        # Test Dict.copy
        self.assertEqual(dict.copy(d), d2)
        self.assertEqual(self.thetype.copy(d), d2)

        # Test invalid usage
        self.assertRaises(TypeError, d.copy, 1)

    @unittest.expectedFailure
    @make_dynamo_test
    def test_fromkeys(self):
        d = self.thetype.fromkeys(["a", "b"], 1)
        self.assertEqual(d, {"a": 1, "b": 1})
        p = self.thetype.fromkeys(["a", "b"], None)
        self.assertEqual(p, {"a": None, "b": None})

        # Test Dict.fromkeys
        d2 = self.thetype.fromkeys(["c", "d"], 2)
        self.assertEqual(d2, {"c": 2, "d": 2})

        # Test invalid usage
        self.assertRaises(TypeError, self.thetype.fromkeys)
        self.assertRaises(TypeError, self.thetype.fromkeys, 1, 2)

    @make_dynamo_test
    def test_get(self):
        d = self.thetype({"a": 1, "b": 2})
        self.assertEqual(d.get("a"), 1)
        self.assertEqual(d.get("c", 3), 3)
        self.assertIsNone(d.get("c"))

        # Test Dict.get
        self.assertEqual(dict.get(d, "b"), 2)
        self.assertEqual(self.thetype.get(d, "b"), 2)

        # Test invalid usage
        self.assertRaises(TypeError, d.get)

    @make_dynamo_test
    def test_items(self):
        d = self.thetype({"a": 1, "b": 2})
        items = d.items()
        self.assertEqual(set(items), {("a", 1), ("b", 2)})

        # Test Dict.items
        self.assertEqual(set(dict.items(d)), {("a", 1), ("b", 2)})
        self.assertEqual(set(self.thetype.items(d)), {("a", 1), ("b", 2)})

        # Test invalid usage
        self.assertRaises(TypeError, d.items, 1)

    @make_dynamo_test
    def test_keys(self):
        d = self.thetype({"a": 1, "b": 2})
        keys = d.keys()
        self.assertEqual(set(keys), {"a", "b"})

        # Test Dict.keys
        self.assertEqual(set(dict.keys(d)), {"a", "b"})
        self.assertEqual(set(self.thetype.keys(d)), {"a", "b"})

        # Test invalid usage
        self.assertRaises(TypeError, d.keys, 1)

    @make_dynamo_test
    def test_pop(self):
        d = self.thetype({"a": 1, "b": 2})
        self.assertEqual(d.pop("a"), 1)
        self.assertEqual(d, {"b": 2})
        self.assertIsNone(d.pop("c", None))

        # Test Dict.pop
        d = self.thetype({"a": 1, "b": 2})
        self.assertEqual(dict.pop(d, "b"), 2)
        self.assertEqual(self.thetype.pop(d, "a"), 1)

        # Test invalid usage
        self.assertRaises(KeyError, d.pop, "c")
        self.assertRaises(TypeError, d.pop)

    @make_dynamo_test
    def test_popitem(self):
        d = self.thetype({"a": 1})
        key, value = d.popitem()
        self.assertEqual(key, "a")
        self.assertEqual(value, 1)
        self.assertEqual(len(d), 0)
        # check LIFO
        d = self.thetype()
        d["a"] = 1
        d["b"] = 2
        self.assertEqual(d.popitem(), ("b", 2))

        # Test Dict.popitem
        d = self.thetype({"a": 1})
        key, value = dict.popitem(d)
        self.assertEqual(key, "a")
        self.assertEqual(value, 1)

        d = self.thetype({"a": 1})
        key, value = self.thetype.popitem(d)
        self.assertEqual(key, "a")
        self.assertEqual(value, 1)

        # Test invalid usage
        if self.thetype is not OrderedDict:
            # OrderedDict accepts a keyword arg
            self.assertRaises(TypeError, d.popitem, 1)

    @make_dynamo_test
    def test_setdefault(self):
        d = self.thetype({"a": 1, "b": 2})
        self.assertEqual(d.setdefault("a", 3), 1)
        self.assertEqual(d.setdefault("c", 3), 3)
        self.assertIsNone(d.setdefault("d"), None)
        self.assertEqual(d, {"a": 1, "b": 2, "c": 3, "d": None})

        # Test Dict.setdefault
        self.assertEqual(dict.setdefault(d, "f", 5), 5)
        self.assertEqual(self.thetype.setdefault(d, "e", 5), 5)

        # Test invalid usage
        self.assertRaises(TypeError, d.setdefault)
        self.assertRaises(TypeError, d.setdefault, [[]])

    @make_dynamo_test
    def test_update(self):
        d = self.thetype({"a": 1, "b": 2})
        d.update({"b": 3, "c": 4})
        self.assertEqual(d, {"a": 1, "b": 3, "c": 4})

        # Test with another dict
        d2 = self.thetype({"d": 5})
        d.update(d2)
        self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5})

        # Test Dict.update
        d3 = self.thetype({"e": 6})
        dict.update(d, d3)
        self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6})
        d4 = self.thetype({"f": 7})
        self.thetype.update(d, d4)
        self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7})

        # Test with keyword arguments
        d.update(f=7, g=8)
        self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8})

        # Test Dict.update with keyword arguments
        self.thetype.update(d, h=9, i=10)
        self.assertEqual(
            d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8, "h": 9, "i": 10}
        )

        # Test invalid usage
        self.assertRaises(TypeError, d.update, 1)

    @make_dynamo_test
    def test_values(self):
        d = self.thetype({"a": 1, "b": 2})
        values = d.values()
        self.assertEqual(set(values), {1, 2})

        # Test Dict.values
        self.assertEqual(set(dict.values(d)), {1, 2})
        self.assertEqual(set(self.thetype.values(d)), {1, 2})

        # Test invalid usage
        self.assertRaises(TypeError, d.values, 1)

    @make_dynamo_test
    def test_type(self):
        d = self.thetype({"a": 1, "b": 2})
        self.assertIsInstance(d, self.thetype)
        self.assertIs(type(d), self.thetype)

    @make_dynamo_test
    def test_dict_type_comparison(self):
        types = (dict, OrderedDict, defaultdict)
        self.assertEqual(self.thetype, self.thetype)
        self.assertTrue(self.thetype is self.thetype)
        for other in types:
            if self.thetype == other:
                continue
            self.assertNotEqual(self.thetype, other)
            self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")

    @make_dynamo_test
    def test_dict___iter__(self):
        d = self.thetype({1: 2})
        it = d.__iter__()
        self.assertEqual(next(it), 1)

    def test_functools_partial_key(self):
        def gn(x, y):
            return x + y

        def fn(x):
            new_dict = {}
            new_gn1 = partial(gn, x=1)
            new_dict[new_gn1] = 5
            return x * new_dict[new_gn1]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        ref = fn(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_namedtuple_functools(self):
        class Container(NamedTuple):
            partial_fn: Callable
            const: int

        def gn(x, y):
            return x + y

        def fn(x):
            new_dict = {}

            new_gn = partial(gn, x=1)
            key = Container(new_gn, 4)
            new_dict[key] = 5
            # Make another key that should hash to the same value
            key1 = Container(new_gn, 4)
            return x * new_dict[key1]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        ref = fn(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_custom_object_as_dict_key(self):
        """Test that custom objects with __hash__ as dict keys are properly handled.

        This test verifies that when using custom objects with overridden __hash__
        and __eq__ as dictionary keys, two instances with the same hash and equality
        should be recognized as the same key.
        """

        class CustomKey:
            def __init__(self, value, name):
                self.value = value
                self.name = name

        def fn(x):
            d = {}
            # Create first instance
            key1 = CustomKey(42, "test")
            d[key1] = x * 2

            # Create second instance with same values - should hash to same value
            key2 = CustomKey(42, "test")
            d[key2] = x * 3  # This should overwrite the first value

            return d[key1] * d[key2]

        x = torch.randn(4)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertTrue(same(opt_fn(x), fn(x)))

    def test_user_defined_object(self):
        class A:
            def __init__(self):
                self.x = {}
                REF[self] = {}

        REF = {}

        def f(a, x):
            REF[a]["foo"] = x
            return x + 1

        opt_f = torch.compile(f, backend="eager", fullgraph=True)

        x = torch.randn(4)
        self.assertTrue(same(f(A(), x), opt_f(A(), x)))


class DictSubclassMethodsTests(DictMethodsTests):
    """Rerun every DictMethodsTests case against SimpleDict, a bare ``dict`` subclass."""

    thetype = SimpleDict


class OrderedDictMethodsTests(DictMethodsTests):
    """Rerun DictMethodsTests against OrderedDict, plus OrderedDict-only checks."""

    thetype = OrderedDict

    # Methods:
    # - popitem - Inherited from DictMethodsTests
    # + move_to_end

    @make_dynamo_test
    def test_move_to_end(self):
        # fromkeys over a string yields keys "a".."e" (values None), so
        # "".join(d) spells out the current key order.
        d = self.thetype.fromkeys("abcde")
        self.assertEqual("".join(d), "abcde")
        d.move_to_end("b")
        self.assertEqual("".join(d), "acdeb")

        # Test OrderedDict.move_to_end
        self.thetype.move_to_end(d, "a")
        self.assertEqual("".join(d), "cdeba")

        # Test last=False
        self.thetype.move_to_end(d, "a", last=False)
        self.assertEqual("".join(d), "acdeb")

        # Test KeyError
        self.assertRaises(KeyError, d.move_to_end, "f")

    def test_cmp_eq_order(self):
        # OrderedDict equality is order-sensitive: same keys, different order.
        # NOTE(review): unlike its neighbours this test is not wrapped in
        # @make_dynamo_test, so it only runs eagerly; confirm that is intended.
        a = self.thetype.fromkeys("abc")
        b = self.thetype.fromkeys("bca")
        self.assertFalse(a == b)

    @make_dynamo_test
    def test_binop_or_return_type(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})

        # Test return type: OrderedDict on either side of | yields an OrderedDict
        self.assertIs(type(d1 | d2), OrderedDict)
        self.assertIs(type(dict(d1) | d2), OrderedDict)
        self.assertIs(type(d1 | dict(d2)), OrderedDict)

    @make_dynamo_test
    def test_binop_ior_return_type(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})

        # Test return type
        d3, d4 = d1.copy(), d2.copy()
        self.assertIs(type(d3.__ior__(d2)), OrderedDict)
        self.assertIs(type(dict.__ior__(d4, d2)), OrderedDict)
        self.assertIs(type(self.thetype.__ior__(d4, d2)), OrderedDict)

        # Unlike |, the in-place form returns its left operand, so the result
        # type follows the left-hand side only.
        d3, d4 = d1.copy(), d2.copy()
        self.assertIs(type(dict.__ior__(d3, dict(d2))), OrderedDict)
        self.assertIs(type(dict.__ior__(dict(d3), d2)), dict)
        self.assertIs(type(dict(d4).__ior__(d2)), dict)

    @make_dynamo_test
    def test_popitem_kwarg(self):
        # OrderedDict.popitem takes `last` by keyword or positionally:
        # last=True pops from the end, last=False from the front.
        d = self.thetype.fromkeys("abcdf")
        self.assertEqual(d.popitem(last=True), ("f", None))
        self.assertEqual(list(d), list("abcd"))
        self.assertEqual(d.popitem(last=False), ("a", None))
        self.assertEqual(list(d), list("bcd"))
        self.assertEqual(d.popitem(False), ("b", None))
        self.assertEqual(list(d), list("cd"))
        self.assertEqual(d.popitem(True), ("d", None))
        self.assertEqual(list(d), list("c"))


class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase):
    """Check that user overrides on an OrderedDict subclass are the ones called."""

    def setUp(self):
        # Same enable_trace_unittest toggle as DictMethodsTests: set for each
        # test, previous value restored in tearDown.
        self._prev_trace_unittest = torch._dynamo.config.enable_trace_unittest
        torch._dynamo.config.enable_trace_unittest = True
        super().setUp()

    def tearDown(self):
        torch._dynamo.config.enable_trace_unittest = self._prev_trace_unittest
        return super().tearDown()

    # Plain ``==``-based assertion overrides.
    def assertEqual(self, x, y):
        self.assertTrue(x == y, f"Expected {x} to be equal to {y}")

    def assertNotEqual(self, x, y):
        self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")

    class OrderedDictSubclass(OrderedDict):
        # NOTE(review): this `get` override is not exercised by any test here.
        def get(self, key, default=None, /):
            return default

        def move_to_end(self, key, last=True, /):
            # change the behavior to something else
            # Drops the key instead of moving it, so a test can tell whether the
            # override or OrderedDict's builtin move_to_end ran.
            self.pop(key)

    thetype = OrderedDictSubclass

    @make_dynamo_test
    def test_move_to_end(self):
        p = self.thetype({"a": 1, "b": 2, "c": 3})
        p.move_to_end("a")
        # The override removed "a"; the builtin would have produced "bca".
        self.assertEqual(list(p.keys()), list("bc"))


if __name__ == "__main__":
    # Run this file's tests through the Dynamo test-case runner.
    torch._dynamo.test_case.run_tests()
